package sshutils

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"path"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// BuffSize is the size in bytes of the buffer used to stream a local file to
// the remote host during an upload (1 MiB by default). It is read each time a
// file is uploaded, so changing it affects subsequent uploads.
var BuffSize = 1 << 20

// Cli holds the connection settings for one remote host together with the
// SSH and SFTP clients that Reconnect opens for it.
type Cli struct {
	IP       string // remote host address (host part of the "host:port" dial address)
	Username string // login user name
	Password string // login password, used for SSH password authentication
	Port     int    // SSH port

	config    ssh.ClientConfig // SSH client configuration built by NewCli
	SshClient *ssh.Client      // SSH client assigned by Reconnect; used by Run
	SftpCient *sftp.Client     // SFTP client assigned by Reconnect (the name's typo is part of the exported API)
}

// SshRecord is the outcome of one Cli.Run call.
type SshRecord struct {
	Command string // the command that was executed
	Result  string // output captured by Cli.Run (stdout and stderr combined)
	Err     error  // error from opening the session or running the command; nil on success
}

// SftpReq describes one transfer for Cli.Upload or Cli.Download.
type SftpReq struct {
	LocalPath  string // path on the local machine
	RemotePath string // path on the remote host
	Overwrite  bool   // false: fail if the destination exists; true: try to remove it and transfer anyway
}

// SftpRecord reports the result of transferring a single path.
type SftpRecord struct {
	LocalPath  string // local path of the entry
	RemotePath string // remote path of the entry
	Err        error  // nil if this entry was transferred successfully
}

// newSftpRecord builds the result record for req, copying both of its paths
// and attaching err (nil means success).
func newSftpRecord(req SftpReq, err error) SftpRecord {
	var rec SftpRecord
	rec.LocalPath, rec.RemotePath = req.LocalPath, req.RemotePath
	rec.Err = err
	return rec
}

// NewCli creates a client for username@ip authenticated with password and
// connects it immediately by calling Reconnect. port is optional: its first
// value is used when given, otherwise the port defaults to 22.
//
// A connection failure is reported through the returned error. The *Cli is
// returned even in that case. Callers can retry with Reconnect, and if only
// the SFTP setup failed, the SSH client is still set and Run remains usable.
//
// SECURITY: the server's host key is NOT verified (ssh.InsecureIgnoreHostKey),
// so the connection is exposed to man-in-the-middle attacks. Only use this on
// networks you trust.
func NewCli(ip string, username string, password string, port ...int) (*Cli, error) {
	cli := &Cli{
		IP:       ip,
		Username: username,
		Password: password,
		Port:     22,
	}
	if len(port) > 0 {
		cli.Port = port[0]
	}
	cli.config = ssh.ClientConfig{
		User:            cli.Username,
		Auth:            []ssh.AuthMethod{ssh.Password(cli.Password)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         10 * time.Second,
	}
	// Previously this error was dropped, so an unreachable host produced a
	// client with a nil SshClient that only failed (by panicking) later.
	if err := Reconnect(cli); err != nil {
		return cli, err
	}
	return cli, nil
}

// 重联
func Reconnect(cli *Cli) error {
	var err error
	addr := fmt.Sprintf("%s:%d", cli.IP, cli.Port)
	cli.SshClient, err = ssh.Dial("tcp", addr, &cli.config)
	if err != nil {
		return err
	}
	cli.SftpCient, err = sftp.NewClient(cli.SshClient)
	if err != nil {
		return err
	}
	return nil
}

// 执行命令
// 注意：这个执行命令和go的os.exec一个尿性，不能获取到真实的错误信息，异常只能用于判断是否正确执行了你的命令
func (cli Cli) Run(command string) SshRecord {
	var session *ssh.Session
	r := SshRecord{Command: command}
	session, r.Err = cli.SshClient.NewSession()
	if r.Err != nil {
		return r
	}
	defer session.Close()
	var buf []byte
	buf, r.Err = session.CombinedOutput(command)
	if r.Err != nil {
		return r
	}
	r.Result = string(buf)
	return r
}

// Upload copies a local file or directory tree to the remote host over SFTP.
//
// req.LocalPath must exist. If req.RemotePath already exists, the upload is
// refused unless req.Overwrite is true. With Overwrite, the existing remote
// entry is removed first on a best-effort basis: removal errors are ignored.
// NOTE(review): SFTP directory removal typically fails for non-empty
// directories, in which case the new tree is merged into the old one. The
// parent of req.RemotePath is created if missing. The result holds one
// record per uploaded file, or a single record describing an early failure.
func (cli Cli) Upload(req SftpReq) []SftpRecord {
	fail := func(err error) []SftpRecord {
		return []SftpRecord{newSftpRecord(req, err)}
	}

	local, err := os.Stat(req.LocalPath)
	if err != nil {
		return fail(err)
	}

	if remote, statErr := cli.SftpCient.Lstat(req.RemotePath); statErr == nil {
		// Something with the same name already exists on the remote side.
		if !req.Overwrite {
			return fail(errors.New(req.RemotePath + "已存在(选择了非覆写模式: SftpReq.Overwrite = false)"))
		}
		// Best effort: if removal fails, the later MkdirAll/Create reports
		// whatever actually prevents the upload.
		if remote.IsDir() {
			_ = cli.SftpCient.RemoveDirectory(req.RemotePath)
		} else {
			_ = cli.SftpCient.Remove(req.RemotePath)
		}
	}

	// Make sure the remote parent directory exists.
	if err := cli.SftpCient.MkdirAll(path.Dir(req.RemotePath)); err != nil {
		return fail(err)
	}

	if !local.IsDir() {
		return []SftpRecord{cli.uploadFile(req)}
	}
	return cli.uploadDir(req)
}

// uploadDir mirrors the local directory req.LocalPath into req.RemotePath.
// It creates the remote directory (with parents), then uploads every entry
// through Upload, which recurses into sub-directories. Records of all
// entries are concatenated in the order returned by ioutil.ReadDir.
// NOTE(review): local paths are joined with slash-only path.Join;
// filepath.Join would be the OS-correct choice for the local side.
func (cli Cli) uploadDir(req SftpReq) []SftpRecord {
	if err := cli.SftpCient.MkdirAll(req.RemotePath); err != nil {
		return []SftpRecord{newSftpRecord(req, err)}
	}
	entries, err := ioutil.ReadDir(req.LocalPath)
	if err != nil {
		return []SftpRecord{newSftpRecord(req, err)}
	}

	records := []SftpRecord{}
	for i := range entries {
		name := entries[i].Name()
		child := SftpReq{
			LocalPath:  path.Join(req.LocalPath, name),
			RemotePath: path.Join(req.RemotePath, name),
			Overwrite:  req.Overwrite,
		}
		records = append(records, cli.Upload(child)...)
	}
	return records
}

// uploadFile copies the single local file req.LocalPath to req.RemotePath,
// creating or truncating the remote file. Data is streamed through a buffer
// of BuffSize bytes; 1 MiB is used when BuffSize is not positive. Any error
// from opening, reading, writing or closing is reported in the record's Err.
func (cli Cli) uploadFile(req SftpReq) SftpRecord {
	r := newSftpRecord(req, nil)

	srcFile, err := os.Open(req.LocalPath)
	if err != nil {
		r.Err = err
		return r
	}
	// Deferred right away so the local file is closed even if Create fails.
	// It is only read, so its close error is not actionable.
	defer func() { _ = srcFile.Close() }()

	destFile, err := cli.SftpCient.Create(req.RemotePath)
	if err != nil {
		r.Err = err
		return r
	}

	size := BuffSize
	if size <= 0 {
		// An empty buffer would make Read return (0, nil) forever, and a
		// negative size would make make() panic.
		size = 1024 * 1024
	}
	buf := make([]byte, size)
	for {
		n, readErr := srcFile.Read(buf)
		// Per the io.Reader contract, consume the n > 0 bytes before
		// looking at the error returned alongside them.
		if n > 0 {
			if _, writeErr := destFile.Write(buf[:n]); writeErr != nil {
				err = writeErr
				break
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			err = readErr
			break
		}
	}

	// Report the remote Close error instead of dropping it, unless an
	// earlier error is already being reported.
	if closeErr := destFile.Close(); err == nil {
		err = closeErr
	}
	r.Err = err
	return r
}

// Download copies a remote file or directory tree to the local machine over
// SFTP.
//
// req.RemotePath must exist. It is inspected with Lstat, so a remote
// symbolic link is never treated as a directory. If req.LocalPath already
// exists, the download is refused unless req.Overwrite is true. With
// Overwrite, the whole local file or tree is deleted first; that deletion is
// best effort and its error is ignored. The result holds one record per
// downloaded file, or a single record describing an early failure.
func (cli Cli) Download(req SftpReq) []SftpRecord {
	fail := func(err error) []SftpRecord {
		return []SftpRecord{newSftpRecord(req, err)}
	}

	remote, err := cli.SftpCient.Lstat(req.RemotePath)
	if err != nil {
		// The remote path is missing or cannot be inspected.
		return fail(err)
	}

	if _, statErr := os.Stat(req.LocalPath); statErr == nil {
		// A local file or directory with the same name is already present.
		if !req.Overwrite {
			return fail(errors.New(req.LocalPath + "已存在(选择了非覆写模式: SftpReq.Overwrite = false)"))
		}
		_ = os.RemoveAll(req.LocalPath)
	}

	if !remote.IsDir() {
		return []SftpRecord{cli.downloadFile(req)}
	}
	return cli.downloadDir(req)
}

// downloadDir mirrors the remote directory req.RemotePath into req.LocalPath.
// It creates the local directory with parents (mode os.ModePerm, before
// umask), then downloads every remote entry through Download, which recurses
// into sub-directories. Records of all entries are concatenated in the order
// returned by the SFTP ReadDir call.
// NOTE(review): local paths are joined with slash-only path.Join;
// filepath.Join would be the OS-correct choice for the local side.
func (cli Cli) downloadDir(req SftpReq) []SftpRecord {
	if err := os.MkdirAll(req.LocalPath, os.ModePerm); err != nil {
		return []SftpRecord{newSftpRecord(req, err)}
	}
	entries, err := cli.SftpCient.ReadDir(req.RemotePath)
	if err != nil {
		return []SftpRecord{newSftpRecord(req, err)}
	}

	records := []SftpRecord{}
	for i := range entries {
		name := entries[i].Name()
		child := SftpReq{
			LocalPath:  path.Join(req.LocalPath, name),
			RemotePath: path.Join(req.RemotePath, name),
			Overwrite:  req.Overwrite,
		}
		records = append(records, cli.Download(child)...)
	}
	return records
}

// downloadFile copies the single remote file req.RemotePath to the local path
// req.LocalPath, creating or truncating the local file. Any error from
// opening, creating, copying or closing is reported in the record's Err.
// Previously, errors from Open and Create were discarded, which led to a
// nil-pointer panic.
func (cli Cli) downloadFile(req SftpReq) SftpRecord {
	r := newSftpRecord(req, nil)

	srcFile, err := cli.SftpCient.Open(req.RemotePath) // remote side
	if err != nil {
		r.Err = err
		return r
	}
	// The remote handle is only read from, so its close error is not actionable.
	defer func() { _ = srcFile.Close() }()

	dstFile, err := os.Create(req.LocalPath) // local side
	if err != nil {
		r.Err = err
		return r
	}

	_, err = srcFile.WriteTo(dstFile)
	// A failed Close on the written file can mean data never reached disk, so
	// report it unless an earlier error is already being reported.
	if closeErr := dstFile.Close(); err == nil {
		err = closeErr
	}
	r.Err = err
	return r
}
